{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Trax_Translation.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1liQji85FAIp"
      },
      "source": [
        "# Machine Translation with Trax\n",
        "\n",
        "Note by Denis Rothman: The original notebook was split into cells.\n",
        "\n",
        "[Reference Code](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h0pjcihTE9fR"
      },
      "source": [
        "#@title Installing Trax\n",
        "# Install before importing anything, so the imports below resolve against\n",
        "# the packages that pip has just installed or upgraded.\n",
        "# %pip (rather than !pip) targets the environment of the running kernel.\n",
        "# NOTE(review): consider pinning a trax version for reproducibility.\n",
        "%pip install -q -U trax\n",
        "\n",
        "import os\n",
        "\n",
        "import numpy as np\n",
        "import trax"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ivTjrL-BMD8i"
      },
      "source": [
        "#@title Creating the model\n",
        "# Hyperparameters follow the pre-trained model config stored at\n",
        "# gs://trax-ml/models/translation/ende_wmt32k.gin\n",
        "model = trax.models.Transformer(\n",
        "    input_vocab_size=33300,\n",
        "    d_model=512,\n",
        "    d_ff=2048,\n",
        "    n_heads=8,\n",
        "    n_encoder_layers=6,\n",
        "    n_decoder_layers=6,\n",
        "    max_len=2048,\n",
        "    mode='predict',\n",
        ")"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oJgRqlrmMKbo"
      },
      "source": [
        "#@title Initializing the model using pre-trained weights\n",
        "# Populate `model` from the published checkpoint; weights_only=True is passed\n",
        "# so that only the weights are read from the file.\n",
        "model.init_from_file(\n",
        "    'gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',\n",
        "    weights_only=True,\n",
        ")"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HvwJ5w-6MQNw"
      },
      "source": [
        "#@title Tokenizing a sentence\n",
        "sentence = 'I am only a machine but I have machine intelligence.'\n",
        "\n",
        "# trax.data.tokenize operates on streams: feed it a one-item iterator and\n",
        "# pull the single tokenized sentence back out of the resulting stream.\n",
        "token_stream = trax.data.tokenize(\n",
        "    iter([sentence]),\n",
        "    vocab_dir='gs://trax-ml/vocabs/',\n",
        "    vocab_file='ende_32k.subword',\n",
        ")\n",
        "tokenized = next(iter(token_stream))"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IVkBQOvmMW9A"
      },
      "source": [
        "#@title Decoding from the Transformer\n",
        "# Build the batched input under a new name instead of overwriting `tokenized`,\n",
        "# so re-running this cell does not stack an extra batch dimension each time.\n",
        "tokenized_batch = tokenized[None, :]  # Add batch dimension.\n",
        "tokenized_translation = trax.supervised.decoding.autoregressive_sample(\n",
        "    model, tokenized_batch, temperature=0.0)  # Higher temperature: more diverse results.\n"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QV2xr8_7Mc4B",
        "outputId": "c78c12ea-84a1-4fd5-fb2e-770fadc19e8b"
      },
      "source": [
        "#@title De-tokenizing and Displaying the Translation\n",
        "# Slice into a new name rather than overwriting `tokenized_translation`, so\n",
        "# re-running this cell always starts from the same decoder output.\n",
        "translation_tokens = tokenized_translation[0][:-1]  # Remove batch and EOS.\n",
        "translation = trax.data.detokenize(translation_tokens,\n",
        "                                   vocab_dir='gs://trax-ml/vocabs/',\n",
        "                                   vocab_file='ende_32k.subword')\n",
        "print(\"The sentence:\", sentence)\n",
        "print(\"The translation:\", translation)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "The sentence: I am only a machine but I have machine intelligence.\n",
            "The translation: Ich bin nur eine Maschine, aber ich habe Maschinenübersicht.\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}